{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import time\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import matplotlib\n",
    "\n",
    "from random import randint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"gpu\" # \"gpu\" selects the CUDA backend/target when the network is loaded; set to \"cpu\" to run on the CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "protoFile = \"pose/coco/pose_deploy_linevec.prototxt\"\n",
    "weightsFile = \"pose/coco/pose_iter_440000.caffemodel\"\n",
    "# Number of keypoint confidence maps read from the network output (one per name below)\n",
    "nPoints = 18\n",
    "# COCO Output Format\n",
    "keypointsMapping = ['Nose', 'Neck', 'R-Sho', 'R-Elb', 'R-Wr', 'L-Sho', \n",
    "                    'L-Elb', 'L-Wr', 'R-Hip', 'R-Knee', 'R-Ank', 'L-Hip', \n",
    "                    'L-Knee', 'L-Ank', 'R-Eye', 'L-Eye', 'R-Ear', 'L-Ear']\n",
    "\n",
    "# Each pair holds the keypointsMapping indices of the two joints that form a limb\n",
    "POSE_PAIRS = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7],\n",
    "              [1,8], [8,9], [9,10], [1,11], [11,12], [12,13],\n",
    "              [1,0], [0,14], [14,16], [0,15], [15,17],\n",
    "              [2,17], [5,16] ]\n",
    "\n",
    "# Indices of the PAF channels corresponding to each entry of POSE_PAIRS\n",
    "# e.g. for POSE_PAIR (1,2) the PAFs are located at indices (31,32) of the output; similarly (1,5) -> (39,40), and so on.\n",
    "mapIdx = [[31,32], [39,40], [33,34], [35,36], [41,42], [43,44], \n",
    "          [19,20], [21,22], [23,24], [25,26], [27,28], [29,30], \n",
    "          [47,48], [49,50], [53,54], [51,52], [55,56], \n",
    "          [37,38], [45,46]]\n",
    "\n",
    "# Line colour for each limb when the skeleton is drawn, indexed like POSE_PAIRS\n",
    "colors = [ [0,100,255], [0,100,255], [0,255,255], [0,100,255], [0,255,255], [0,100,255],\n",
    "         [0,255,0], [255,200,100], [255,0,255], [0,255,0], [255,200,100], [255,0,255],\n",
    "         [0,0,255], [255,0,0], [200,200,0], [255,0,0], [200,200,0], [0,0,0]]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the Keypoints using Non Maximum Suppression on the Confidence Map\n",
    "def getKeypoints(probMap, threshold=0.1):\n",
    "    \"\"\"Locate the peak of every above-threshold blob in a confidence map.\n",
    "\n",
    "    Args:\n",
    "        probMap: 2-D confidence map for a single body part.\n",
    "        threshold: minimum smoothed confidence for a pixel to belong to a blob.\n",
    "\n",
    "    Returns:\n",
    "        A list with one ``(x, y, score)`` tuple per blob, where ``score`` is read\n",
    "        from the unsmoothed ``probMap`` at the location of the smoothed maximum.\n",
    "    \"\"\"\n",
    "    mapSmooth = cv2.GaussianBlur(probMap, (3, 3), 0, 0)\n",
    "    mapMask = np.uint8(mapSmooth > threshold)\n",
    "\n",
    "    # find the blobs\n",
    "    # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but\n",
    "    # (image, contours, hierarchy) in 3.x; index -2 is the contour list in both.\n",
    "    contours = cv2.findContours(mapMask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]\n",
    "\n",
    "    # for each blob find the maxima\n",
    "    keypoints = []\n",
    "    for cnt in contours:\n",
    "        blobMask = np.zeros(mapMask.shape)\n",
    "        blobMask = cv2.fillConvexPoly(blobMask, cnt, 1)\n",
    "        maskedProbMap = mapSmooth * blobMask\n",
    "        _, _, _, maxLoc = cv2.minMaxLoc(maskedProbMap)\n",
    "        keypoints.append(maxLoc + (probMap[maxLoc[1], maxLoc[0]],))\n",
    "\n",
    "    return keypoints\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use the following equation for finding valid joint-pairs\n",
    "\n",
    "![](./pose-estimation-paf-equation.png)\n",
    "\n",
    "In the above equation:\n",
    "\n",
    "L is the PAF;\n",
    "\n",
    "d is the unit vector pointing from the first joint to the second;\n",
    "\n",
    "p is the interpolated point between two joints;\n",
    "\n",
    "It is implemented using the dot product between the PAF and the vector $d_{ij}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find valid connections between the different joints of all persons present\n",
    "def getValidPairs(output):\n",
    "    \"\"\"Pair up candidate keypoints into limbs using the Part Affinity Fields (PAFs).\n",
    "\n",
    "    Reads the notebook globals ``mapIdx``, ``POSE_PAIRS``, ``detected_keypoints``,\n",
    "    ``frameWidth`` and ``frameHeight``; they must be defined before this is called.\n",
    "\n",
    "    Args:\n",
    "        output: network output; ``output[0, c, :, :]`` is used as the 2-D map of channel ``c``.\n",
    "\n",
    "    Returns:\n",
    "        ``(valid_pairs, invalid_pairs)``. ``valid_pairs[k]`` is an ``(n, 3)`` array of\n",
    "        ``[id_A, id_B, score]`` rows for ``POSE_PAIRS[k]``, or ``[]`` when one of the two\n",
    "        joints has no candidates; ``invalid_pairs`` lists the indices of those empty pairs.\n",
    "    \"\"\"\n",
    "    valid_pairs = []\n",
    "    invalid_pairs = []\n",
    "    n_interp_samples = 10\n",
    "    paf_score_th = 0.1\n",
    "    conf_th = 0.7\n",
    "    # loop for every POSE_PAIR\n",
    "    for k in range(len(mapIdx)):\n",
    "        # A->B constitute a limb; the two PAF channels are sampled as (first, second) vector components\n",
    "        pafA = cv2.resize(output[0, mapIdx[k][0], :, :], (frameWidth, frameHeight))\n",
    "        pafB = cv2.resize(output[0, mapIdx[k][1], :, :], (frameWidth, frameHeight))\n",
    "\n",
    "        # Candidate keypoints for the two ends of the limb\n",
    "        candA = detected_keypoints[POSE_PAIRS[k][0]]\n",
    "        candB = detected_keypoints[POSE_PAIRS[k][1]]\n",
    "\n",
    "        # If no keypoints are detected for either joint there is nothing to connect\n",
    "        if len(candA) == 0 or len(candB) == 0:\n",
    "            print(\"No Connection : k = {}\".format(k))\n",
    "            invalid_pairs.append(k)\n",
    "            valid_pairs.append([])\n",
    "            continue\n",
    "\n",
    "        # For every joint in candA keep the best-scoring valid partner in candB\n",
    "        connections = []\n",
    "        for a in candA:\n",
    "            max_j = -1\n",
    "            maxScore = -1\n",
    "            for j, b in enumerate(candB):\n",
    "                # Find d_ij (unit vector from A to B); coincident points cannot define a direction\n",
    "                d_ij = np.subtract(b[:2], a[:2])\n",
    "                norm = np.linalg.norm(d_ij)\n",
    "                if not norm:\n",
    "                    continue\n",
    "                d_ij = d_ij / norm\n",
    "\n",
    "                # Find p(u): pixel coordinates of evenly spaced samples on the segment A->B\n",
    "                xs = np.rint(np.linspace(a[0], b[0], num=n_interp_samples)).astype(int)\n",
    "                ys = np.rint(np.linspace(a[1], b[1], num=n_interp_samples)).astype(int)\n",
    "\n",
    "                # Find L(p(u)): the PAF vector at every sample, shape (n_interp_samples, 2)\n",
    "                paf_interp = np.stack([pafA[ys, xs], pafB[ys, xs]], axis=1)\n",
    "\n",
    "                # Find E: projection of the PAF on the limb direction\n",
    "                paf_scores = np.dot(paf_interp, d_ij)\n",
    "                avg_paf_score = sum(paf_scores) / len(paf_scores)\n",
    "\n",
    "                # Valid if the fraction of samples aligned with the PAF exceeds conf_th\n",
    "                aligned_fraction = np.count_nonzero(paf_scores > paf_score_th) / n_interp_samples\n",
    "                if aligned_fraction > conf_th and avg_paf_score > maxScore:\n",
    "                    max_j = j\n",
    "                    maxScore = avg_paf_score\n",
    "\n",
    "            # Record the connection as [keypoint id of A, keypoint id of B, score]\n",
    "            if max_j >= 0:\n",
    "                connections.append([a[3], candB[max_j][3], maxScore])\n",
    "\n",
    "        # Append the detected connections to the overall list (always an (n, 3) array, even when n == 0)\n",
    "        valid_pairs.append(np.array(connections, dtype=float).reshape(-1, 3))\n",
    "    print(valid_pairs)\n",
    "    return valid_pairs, invalid_pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Assemble one row of keypoint ids per person from the valid limb connections.\n",
    "# Joints are matched to a person through the unique id carried by every keypoint.\n",
    "def getPersonwiseKeypoints(valid_pairs, invalid_pairs):\n",
    "    \"\"\"Group connected keypoints into one row per person.\n",
    "\n",
    "    Reads the notebook globals ``mapIdx``, ``POSE_PAIRS`` and ``keypoints_list``.\n",
    "\n",
    "    Args:\n",
    "        valid_pairs: per-limb arrays of ``[id_A, id_B, score]`` rows.\n",
    "        invalid_pairs: limb indices that must be skipped.\n",
    "\n",
    "    Returns:\n",
    "        An ``(n_people, 19)`` array: columns 0-17 hold the keypoint id of each body\n",
    "        part (-1 when missing) and the last column is the accumulated score.\n",
    "    \"\"\"\n",
    "    people = -1 * np.ones((0, 19))\n",
    "\n",
    "    for pair_idx in range(len(mapIdx)):\n",
    "        if pair_idx in invalid_pairs:\n",
    "            continue\n",
    "        joint_a, joint_b = POSE_PAIRS[pair_idx]\n",
    "\n",
    "        for id_a, id_b, paf_score in valid_pairs[pair_idx]:\n",
    "            # rows that already own joint A of this connection; the first one wins\n",
    "            owners = np.flatnonzero(people[:, joint_a] == id_a)\n",
    "\n",
    "            if owners.size:\n",
    "                person = people[owners[0]]\n",
    "                person[joint_b] = id_b\n",
    "                person[-1] += keypoints_list[int(id_b), 2] + paf_score\n",
    "            elif pair_idx < 17:\n",
    "                # joint A belongs to nobody yet: start a new person\n",
    "                # (pairs with index 17 and above never create one)\n",
    "                person = -1 * np.ones(19)\n",
    "                person[joint_a] = id_a\n",
    "                person[joint_b] = id_b\n",
    "                # keypoint scores of both joints plus the paf score\n",
    "                person[-1] = keypoints_list[int(id_a), 2] + keypoints_list[int(id_b), 2] + paf_score\n",
    "                people = np.vstack([people, person])\n",
    "    return people"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = cv2.imread(\"group.jpg\")\n",
    "# cv2.imread returns None instead of raising when the file is missing or unreadable\n",
    "if image1 is None:\n",
    "    raise FileNotFoundError(\"Could not read image 'group.jpg'\")\n",
    "frameWidth = image1.shape[1]\n",
    "frameHeight = image1.shape[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the network and pass the image through the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "net = cv2.dnn.readNetFromCaffe(protoFile, weightsFile)\n",
    "\n",
    "if device == \"cpu\":\n",
    "    # DNN_TARGET_CPU is a *target* constant, so it belongs in setPreferableTarget;\n",
    "    # the backend is left at its default.\n",
    "    net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)\n",
    "    print(\"Using CPU device\")\n",
    "elif device == \"gpu\":\n",
    "    net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)\n",
    "    net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)\n",
    "    print(\"Using GPU device\")\n",
    "\n",
    "# Fix the input Height and get the width according to the Aspect Ratio\n",
    "inHeight = 368\n",
    "inWidth = int((inHeight/frameHeight)*frameWidth)\n",
    "\n",
    "# Scale pixel values by 1/255, no mean subtraction, keep the BGR channel order\n",
    "inpBlob = cv2.dnn.blobFromImage(image1, 1.0 / 255, (inWidth, inHeight),\n",
    "                          (0, 0, 0), swapRB=False, crop=False)\n",
    "\n",
    "net.setInput(inpBlob)\n",
    "output = net.forward()\n",
    "# The reported time covers loading the model as well as the forward pass\n",
    "print(\"Time Taken = {}\".format(time.time() - t))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Slice the probability map of a specific keypoint (e.g. the Nose) from the output and plot it as a heatmap, after resizing, on top of the image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channel 0 of the output is the confidence map for keypointsMapping[0] ('Nose')\n",
    "i = 0\n",
    "probMap = output[0, i, :, :]\n",
    "# Upscale the map to the image size so it can be overlaid on the photo\n",
    "probMap = cv2.resize(probMap, (frameWidth, frameHeight))\n",
    "plt.figure(figsize=[14,10])\n",
    "plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))\n",
    "plt.imshow(probMap, alpha=0.6)\n",
    "plt.colorbar()\n",
    "plt.axis(\"off\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# detected_keypoints[part] -> list of (x, y, score, id) tuples for that body part\n",
    "# keypoints_list[id]       -> (x, y, score) row of the keypoint with that id\n",
    "detected_keypoints = []\n",
    "keypoint_rows = []\n",
    "keypoint_id = 0\n",
    "threshold = 0.1\n",
    "\n",
    "for part in range(nPoints):\n",
    "    probMap = output[0,part,:,:]\n",
    "    probMap = cv2.resize(probMap, (image1.shape[1], image1.shape[0]))\n",
    "    keypoints = getKeypoints(probMap, threshold)\n",
    "    print(\"Keypoints - {} : {}\".format(keypointsMapping[part], keypoints))\n",
    "    keypoints_with_id = []\n",
    "    for keypoint in keypoints:\n",
    "        # ids are assigned in detection order, so they double as row indices of keypoints_list\n",
    "        keypoints_with_id.append(keypoint + (keypoint_id,))\n",
    "        keypoint_rows.append(keypoint)\n",
    "        keypoint_id += 1\n",
    "\n",
    "    detected_keypoints.append(keypoints_with_id)\n",
    "\n",
    "# Build the array once instead of re-stacking it for every keypoint; reshape keeps (0, 3) when nothing was found\n",
    "keypoints_list = np.array(keypoint_rows, dtype=float).reshape(-1, 3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mark every detected keypoint with a small filled dot on a copy of the image\n",
    "frameClone = image1.copy()\n",
    "for i in range(nPoints):\n",
    "    for keypoint in detected_keypoints[i]:\n",
    "        cv2.circle(frameClone, keypoint[0:2], 3, [0,0,255], -1, cv2.LINE_AA)\n",
    "plt.figure(figsize=[15,15])\n",
    "# matplotlib expects RGB while the frame is in BGR channel order\n",
    "plt.imshow(cv2.cvtColor(frameClone, cv2.COLOR_BGR2RGB))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score candidate joint connections with the PAFs; getValidPairs also reads detected_keypoints, frameWidth and frameHeight set in earlier cells\n",
    "valid_pairs, invalid_pairs = getValidPairs(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the connected keypoints by person; each row holds 18 keypoint ids (-1 = missing) followed by a score\n",
    "personwiseKeypoints = getPersonwiseKeypoints(valid_pairs, invalid_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the first 17 POSE_PAIRS are drawn; the last two entries (shoulder-to-ear links) are skipped\n",
    "for i in range(17):\n",
    "    joint_a, joint_b = POSE_PAIRS[i]\n",
    "    for person in personwiseKeypoints:\n",
    "        id_a, id_b = person[joint_a], person[joint_b]\n",
    "        # -1 marks a joint that was not found for this person\n",
    "        if id_a == -1 or id_b == -1:\n",
    "            continue\n",
    "        x_a, y_a = keypoints_list[int(id_a), :2]\n",
    "        x_b, y_b = keypoints_list[int(id_b), :2]\n",
    "        # plain Python ints avoid relying on OpenCV accepting NumPy scalars as point coordinates\n",
    "        cv2.line(frameClone, (int(x_a), int(y_a)), (int(x_b), int(y_b)), colors[i], 3, cv2.LINE_AA)\n",
    "\n",
    "plt.figure(figsize=[15,15])\n",
    "plt.imshow(frameClone[:,:,[2,1,0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
